{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Whether the dataset has already been reorganized; set to False to run the reorganization below.\n",
    "org = True\n",
    "if not org:\n",
    "    import os\n",
    "    import shutil\n",
    "\n",
    "    def reorg_cifar10_data(data_dir, label_file, train_dir, input_dir, valid_ratio):\n",
    "        \"\"\"Copy the raw training images into per-label train/valid/train_valid folders.\n",
    "\n",
    "        Reads `label_file` from `data_dir`, skipping its first (header) line and\n",
    "        treating every other non-blank line as `id,label`. Each image found in\n",
    "        `data_dir/train_dir` is copied to `data_dir/input_dir/train_valid/<label>`\n",
    "        and also to either `.../train/<label>` or `.../valid/<label>`: per label, the\n",
    "        first `int(num_train * (1 - valid_ratio)) // num_labels` images go to `train`\n",
    "        and the remaining ones go to `valid`.\n",
    "\n",
    "        Directory entries whose name does not start with a numeric id followed by '.'\n",
    "        (for example macOS `.DS_Store`) are ignored, and images are processed in sorted\n",
    "        file-name order so that the split is reproducible between runs.\n",
    "\n",
    "        Args:\n",
    "            data_dir: root directory holding the label file and the image folder.\n",
    "            label_file: name of the CSV label file inside `data_dir`.\n",
    "            train_dir: name of the folder inside `data_dir` with the raw images.\n",
    "            input_dir: name of the output folder created inside `data_dir`.\n",
    "            valid_ratio: fraction of the images kept out of the `train` split.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if `valid_ratio` leaves no image for the train split or\n",
    "                makes the train split cover every image.\n",
    "            KeyError: if an image id is missing from the label file.\n",
    "        \"\"\"\n",
    "        with open(os.path.join(data_dir, label_file), 'r') as f:\n",
    "            # The first line is the CSV header; blank lines carry no label.\n",
    "            tokens = [l.rstrip().split(',') for l in f.readlines()[1:] if l.strip()]\n",
    "        idx_label = dict((int(idx), label) for idx, label in tokens)\n",
    "        labels = set(idx_label.values())\n",
    "\n",
    "        src_dir = os.path.join(data_dir, train_dir)\n",
    "        # os.listdir returns entries in arbitrary order and may include stray files;\n",
    "        # keep only '<id>.<ext>' names and fix the order.\n",
    "        train_files = sorted(name for name in os.listdir(src_dir) if name.split('.')[0].isdigit())\n",
    "        num_train = len(train_files)\n",
    "        num_train_tuning = int(num_train * (1 - valid_ratio))\n",
    "        assert 0 < num_train_tuning < num_train\n",
    "        num_train_tuning_per_label = num_train_tuning // len(labels)\n",
    "\n",
    "        def copy_to_split(split, label, file_name):\n",
    "            # Copy one image into <input_dir>/<split>/<label>, creating the folder if needed.\n",
    "            dst_dir = os.path.join(data_dir, input_dir, split, label)\n",
    "            os.makedirs(dst_dir, exist_ok=True)\n",
    "            shutil.copy(os.path.join(src_dir, file_name), dst_dir)\n",
    "\n",
    "        label_count = dict()\n",
    "        for train_file in train_files:\n",
    "            label = idx_label[int(train_file.split('.')[0])]\n",
    "            copy_to_split('train_valid', label, train_file)\n",
    "            if label not in label_count or label_count[label] < num_train_tuning_per_label:\n",
    "                copy_to_split('train', label, train_file)\n",
    "                label_count[label] = label_count.get(label, 0) + 1\n",
    "            else:\n",
    "                copy_to_split('valid', label, train_file)\n",
    "\n",
    "\n",
    "    # Root of the downloaded data and the names used inside it.\n",
    "    data_dir = '/Users/Sinyer/Desktop/kaggle'\n",
    "    label_file = 'trainLabels.csv'\n",
    "    train_dir = 'train'\n",
    "    input_dir = 'train_valid_test'\n",
    "    # Ratio handed to the reorganization step for the validation split.\n",
    "    valid_ratio = 0.1\n",
    "\n",
    "    reorg_cifar10_data(\n",
    "        data_dir=data_dir,\n",
    "        label_file=label_file,\n",
    "        train_dir=train_dir,\n",
    "        input_dir=input_dir,\n",
    "        valid_ratio=valid_ratio,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from mxnet import init, gluon, nd, autograd, image\n",
    "from mxnet.gluon import nn\n",
    "from mxnet.gluon.data import vision\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from time import time\n",
    "# Context object for the GPU device.\n",
    "ctx = mx.gpu()\n",
    "\n",
    "# Location of the reorganized dataset on disk.\n",
    "data_dir = '/Users/Sinyer/Desktop/kaggle'\n",
    "label_file = 'trainLabels.csv'\n",
    "input_dir = 'train_valid_test'\n",
    "input_str = '{}/{}/'.format(data_dir, input_dir)\n",
    "\n",
    "# Dataset over the combined train+valid folder; its `synsets` list is what maps\n",
    "# class indices back to label names when the submission is written.\n",
    "train_valid_ds = vision.ImageFolderDataset(input_str + 'train_valid', flag=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Use the CPU context for the saved predictions loaded below.\n",
    "ctx = mx.cpu()\n",
    "def mesuare_softmax_sum(preds, weight_list):\n",
    "    \"\"\"Blend several models' outputs and return the winning class index per row.\n",
    "\n",
    "    Each `preds[i]` is passed through a softmax over axis 1, scaled by\n",
    "    `weight_list[i]`, and the scaled results are summed. The argmax over axis 1 of\n",
    "    that sum is cast to int, converted to a NumPy array and reduced modulo 10.\n",
    "\n",
    "    NOTE(review): presumably each `preds[i]` is a (num_samples, num_classes) score\n",
    "    array from one model - confirm against the code that saved them.\n",
    "    \"\"\"\n",
    "    def weighted_probs(i):\n",
    "        # Softmax of the i-th model's output, scaled by that model's weight.\n",
    "        return nd.softmax(preds[i], axis=1) * weight_list[i]\n",
    "\n",
    "    ensemble = weighted_probs(0)\n",
    "    for i in range(1, len(preds)):\n",
    "        ensemble = ensemble + weighted_probs(i)\n",
    "    class_ids = ensemble.argmax(axis=1).astype(int).asnumpy()\n",
    "    return class_ids % 10\n",
    "\n",
    "# File names of the saved per-model predictions inside the kaggle directory.\n",
    "model_list = [\n",
    "    'resnet164_e255_focal_clip',\n",
    "    'res164__2_e255_focal_clip_all_data',\n",
    "    'resnet164_e300',\n",
    "    'resnet164_e0-255_focal_clip',\n",
    "    'shelock_densenet_orign',\n",
    "    'shelock_resnet_orign',\n",
    "    'wrn',\n",
    "    'wrn',\n",
    "]\n",
    "# One blending weight per entry of model_list, in the same order.\n",
    "weight_list = [0.9535, 0.9540, 0.95270, 0.95, 0.9539, 0.95, 0.9596, 0.9596]\n",
    "\n",
    "preds = []\n",
    "for result_name in model_list:\n",
    "    saved_arrays = nd.load(\"/Users/Sinyer/Desktop/kaggle/\" + result_name)\n",
    "    # Keep the first item of the loaded file and move it to the active context.\n",
    "    preds.append(saved_arrays[0].as_in_context(ctx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the ensemble result under its own name: rebinding `preds` would replace the\n",
    "# list of per-model arrays and make this cell fail when it is run a second time.\n",
    "pred_labels = mesuare_softmax_sum(preds, weight_list)\n",
    "\n",
    "# Ids 1..300000 ordered as strings ('1', '10', '100', ...).\n",
    "# NOTE(review): presumably this matches the file-name order in which the test\n",
    "# predictions were produced - confirm against the code that saved them.\n",
    "sorted_ids = sorted(range(1, 300000 + 1), key=str)\n",
    "\n",
    "df = pd.DataFrame({'id': sorted_ids, 'label': pred_labels})\n",
    "# Map each predicted class index to its label name via the dataset's synsets.\n",
    "df['label'] = df['label'].apply(lambda x: train_valid_ds.synsets[x])\n",
    "df.to_csv('/Users/Sinyer/Desktop/ensemble.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
